Skip to content

Conversation

@a-sidorova
Copy link

@a-sidorova a-sidorova commented Nov 20, 2025

Description:

  • Added the direct lowering pass for torch.aten.convolution_backward from Torch to Linalg. Enabled this pass by default. The pass generates linalg.generic ops instead of linalg.conv_<> for better lowering.
  • Removed the previous pass DecomposeAtenConvolutionBackwardOp from Torch/Transforms/DecomposeComplexOps.cpp.
  • Created new lit tests for backward convolution in the separate file convolution_backward.mlir. Also added more test cases for better test coverage.
  • Added new e2e tests for backward convolution for better test coverage.

Issue:

@a-sidorova a-sidorova force-pushed the feature/linalg_conv_bwd branch from 4f1cb20 to 8e2b616 Compare November 21, 2025 13:31
@a-sidorova a-sidorova marked this pull request as ready for review November 21, 2025 13:36
@a-sidorova
Copy link
Author

@zjgarvey hey! May I ask you to take a look when you're available? Thank you in advance for the review.

@zjgarvey zjgarvey self-requested a review November 21, 2025 19:36
Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This is an excellent start.

We need to keep the existing decomposition for other backends. I have a few other comments for you to look at, but that's the biggest blocker right now.

rewriter, loc, op.getResultTypes()[2], gradOutput, gradIntList,
cstFalse, cstNone);
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep the decomposition, E.g., TOSA and StableHLO still rely on this pattern. The purpose of the backend_legal_ops option in torch-decompose-complex-ops is specifically to prevent selected decomposition patterns.

SmallVector<int64_t> weightFlipDims;
weightFlipDims.reserve(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
weightFlipDims.push_back(spatialStartDimIdx + i);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the weight shape is static at index i, and the dim size is 1 there, don't add to the flip. We definitely see a lot of 1x1 filter convs and the noop flip doesn't get folded easily IIRC.

createZeroInitTensor(rewriter, loc, sizes, gradOutputDTy);
gradOutputSliced = tensor::InsertSliceOp::create(
rewriter, loc,
torch_to_linalg::removeSizeInformation(rewriter, loc,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove the size info?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also maybe "sliced" is a misleading name. Scattered? Or something generic like "Modified" since you are just padding when stride == 1.

createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
SmallVector<ReassociationIndices> gradWeightCollapseIndices;
if (isGroupedConvBwd) {
auto gradWeightInitExpanded = expandGroups(gradWeightInit, 0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the init just be made on the expanded shape here (instead of expanding the init)? This probably gets folded, but I think it would be better to generate simpler IR when possible.

// `c` is the input channel dimension, `f` is the output channel
// dimension, `o` is the input spatial dimension, `k` is the kernel
// dimension, `d0` is dilation. `x` is the input tensor, `dLdy` is the
// gradient of the output tensor. `dLdx` is the data-gradient tensor.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to mention that dLdy is the stride/padding modified grad output tensor here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And that w is flipped along spatial dims.

}

static linalg::GenericOp
createGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, Value out,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There might be a util for this already like "createReductionGeneric` or something. In any case, might be good to call this something a little more specific (pun intended).

Comment on lines +2061 to +2277
if (!isGrouped) {
if (numSpatialDims == 1) {
AffineExpr n, c, o, f, k;
bindDims(context, n, c, o, f, k);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> goExprs = {n, f, d0 * k + o};
SmallVector<AffineExpr> weiExprs = {f, c, k};
SmallVector<AffineExpr> outExprs = {n, c, o};
indexingMaps = {AffineMap::get(5, 0, goExprs, context),
AffineMap::get(5, 0, weiExprs, context),
AffineMap::get(5, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr n, c, oh, ow, f, kh, kw;
bindDims(context, n, c, oh, ow, f, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> goExprs = {n, f, d0 * kh + oh, d1 * kw + ow};
SmallVector<AffineExpr> weiExprs = {f, c, kh, kw};
SmallVector<AffineExpr> outExprs = {n, c, oh, ow};
indexingMaps = {AffineMap::get(7, 0, goExprs, context),
AffineMap::get(7, 0, weiExprs, context),
AffineMap::get(7, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction,
IT::reduction};
} else {
AffineExpr n, c, od, oh, ow, f, kd, kh, kw;
bindDims(context, n, c, od, oh, ow, f, kd, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> goExprs = {n, f, d0 * kd + od, d1 * kh + oh,
d2 * kw + ow};
SmallVector<AffineExpr> weiExprs = {f, c, kd, kh, kw};
SmallVector<AffineExpr> outExprs = {n, c, od, oh, ow};
indexingMaps = {AffineMap::get(9, 0, goExprs, context),
AffineMap::get(9, 0, weiExprs, context),
AffineMap::get(9, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction, IT::reduction};
}
} else {
if (numSpatialDims == 1) {
AffineExpr n, g, cg, o, fg, k;
bindDims(context, n, g, cg, o, fg, k);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * k + o};
SmallVector<AffineExpr> weiExprs = {g, fg, cg, k};
SmallVector<AffineExpr> outExprs = {n, g, cg, o};
indexingMaps = {AffineMap::get(6, 0, goExprs, context),
AffineMap::get(6, 0, weiExprs, context),
AffineMap::get(6, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr n, g, cg, oh, ow, fg, kh, kw;
bindDims(context, n, g, cg, oh, ow, fg, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * kh + oh,
d1 * kw + ow};
SmallVector<AffineExpr> weiExprs = {g, fg, cg, kh, kw};
SmallVector<AffineExpr> outExprs = {n, g, cg, oh, ow};
indexingMaps = {AffineMap::get(8, 0, goExprs, context),
AffineMap::get(8, 0, weiExprs, context),
AffineMap::get(8, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction};
} else {
AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw;
bindDims(context, n, g, cg, od, oh, ow, fg, kd, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> goExprs = {
n, g, fg, d0 * kd + od, d1 * kh + oh, d2 * kw + ow};
SmallVector<AffineExpr> weiExprs = {g, fg, cg, kd, kh, kw};
SmallVector<AffineExpr> outExprs = {n, g, cg, od, oh, ow};
indexingMaps = {AffineMap::get(10, 0, goExprs, context),
AffineMap::get(10, 0, weiExprs, context),
AffineMap::get(10, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction, IT::reduction,
IT::reduction};
}
}
}

static void initIndexingMapsAndIteratorTypesForWeightBwd(
OpBuilder &rewriter, MLIRContext *context, bool isGrouped,
int numSpatialDims, const SmallVector<int64_t> &strideInts,
const SmallVector<int64_t> &dilationInts,
SmallVector<AffineMap> &indexingMaps, SmallVector<IT> &iteratorTypes) {
// To calculate convolution backward-weight, we use generic operation.
// The generic operation is a generalization of the convolution operation
// that can handle any number of spatial dimensions.
// The generic operation is defined as follows:
// ```
// dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o]
// for n in range(batch_size) for o in range(output_spatial_dims))
// ```
// where `n` is the batch dimension, `g` is the group dimension,
// `c` is the input channel dimension, `f` is the output channel
// dimension, `o` is the output spatial dimension, `k` is the kernel
// dimension, `d0` is dilation and `s0` is stride. `x` is the input
// tensor, `dLdy` is the gradient of the output tensor. `dLdw` is the
// weight-gradient tensor.
if (!isGrouped) {
if (numSpatialDims == 1) {
AffineExpr f, c, k, n, o;
bindDims(context, f, c, k, n, o);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> inExprs = {n, c, d0 * k + s0 * o};
SmallVector<AffineExpr> goExprs = {n, f, o};
SmallVector<AffineExpr> outExprs = {f, c, k};
indexingMaps = {AffineMap::get(5, 0, inExprs, context),
AffineMap::get(5, 0, goExprs, context),
AffineMap::get(5, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr f, c, kh, kw, n, oh, ow;
bindDims(context, f, c, kh, kw, n, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> inExprs = {n, c, d0 * kh + s0 * oh,
d1 * kw + s1 * ow};
SmallVector<AffineExpr> goExprs = {n, f, oh, ow};
SmallVector<AffineExpr> outExprs = {f, c, kh, kw};
indexingMaps = {AffineMap::get(7, 0, inExprs, context),
AffineMap::get(7, 0, goExprs, context),
AffineMap::get(7, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction,
IT::reduction};
} else {
AffineExpr f, c, kd, kh, kw, n, od, oh, ow;
bindDims(context, f, c, kd, kh, kw, n, od, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> inExprs = {
n, c, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
SmallVector<AffineExpr> goExprs = {n, f, od, oh, ow};
SmallVector<AffineExpr> outExprs = {f, c, kd, kh, kw};
indexingMaps = {AffineMap::get(9, 0, inExprs, context),
AffineMap::get(9, 0, goExprs, context),
AffineMap::get(9, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction, IT::reduction};
}
} else {
if (numSpatialDims == 1) {
AffineExpr g, fg, cg, k, n, o;
bindDims(context, g, fg, cg, k, n, o);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * k + s0 * o};
SmallVector<AffineExpr> goExprs = {n, g, fg, o};
SmallVector<AffineExpr> outExprs = {g, fg, cg, k};
indexingMaps = {AffineMap::get(6, 0, inExprs, context),
AffineMap::get(6, 0, goExprs, context),
AffineMap::get(6, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr g, fg, cg, kh, kw, n, oh, ow;
bindDims(context, g, fg, cg, kh, kw, n, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * kh + s0 * oh,
d1 * kw + s1 * ow};
SmallVector<AffineExpr> goExprs = {n, g, fg, oh, ow};
SmallVector<AffineExpr> outExprs = {g, fg, cg, kh, kw};
indexingMaps = {AffineMap::get(8, 0, inExprs, context),
AffineMap::get(8, 0, goExprs, context),
AffineMap::get(8, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction};
} else {
AffineExpr g, fg, cg, kd, kh, kw, n, od, oh, ow;
bindDims(context, g, fg, cg, kd, kh, kw, n, od, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> inExprs = {
n, g, cg, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
SmallVector<AffineExpr> goExprs = {n, g, fg, od, oh, ow};
SmallVector<AffineExpr> outExprs = {g, fg, cg, kd, kh, kw};
indexingMaps = {AffineMap::get(10, 0, inExprs, context),
AffineMap::get(10, 0, goExprs, context),
AffineMap::get(10, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction, IT::reduction,
IT::reduction};
}
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There must be a better way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g., you could make the AffineExprs for stride, dilation, spatial dims, etc. SmallVector<AffineExpr>. I don't even think there need to be conditionals on anything other than like:

SmallVector<AffineExpr> lhsExprs = isGrouped ? {n, g, c} : {n, c};
// loop over spatial dims and add expressions...

Everything else can be like:

int64_t numIterators = 3; // batch, parallel channel, reduction channel
numIterators += static_cast<int64_t>(isGrouped);
numIterators += numSpatialDims*2 // parallel spatial dims, reduction spatial dims
indexingMaps = {
    AffineMap::get(numIterators, lhsExprs, context),
    AffineMap::get(numIterators, rhsExprs, context),
    AffineMap::get(numIterators, outExprs, context)
};

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants